unset LD_LIBRARY_PATH
CONFIG=$1  # 配置文件作为第一个参数

NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# 硬编码的默认设备列表（可修改这里指定默认GPU）
DEFAULT_CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"

# 设置 CUDA_VISIBLE_DEVICES
if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
    # 如果没有设置，使用硬编码的默认值
    export CUDA_VISIBLE_DEVICES=$DEFAULT_CUDA_VISIBLE_DEVICES
fi

# 计算可见GPU数量
IFS=',' read -ra devices <<< "$CUDA_VISIBLE_DEVICES"
GPUS=${#devices[@]}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --seed 0 \
    --launcher pytorch "${@:2}"  # 注意这里改为 ${@:2}